"""
@author : linrh
@homepage : https://gitee.com/linrh-DUT
@version: 1.0.0
@when : 2023/5/29
@file: data.py.py
"""
import os
import torch
from torchvision import datasets, transforms
from conf import *

# Build the training data loader for the MNIST training split.
# Images are resized with transforms.Resize(opt.img_size) (an int resizes the
# smaller edge; MNIST digits are square), converted to tensors, then
# normalised with mean 0.5 / std 0.5 so ToTensor's [0, 1] range becomes [-1, 1].
# The dataset is downloaded into the root directory if it is not already there.
_MNIST_ROOT = "data/mnist"
os.makedirs(_MNIST_ROOT, exist_ok=True)

_mnist_preprocess = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

_mnist_train = datasets.MNIST(
    _MNIST_ROOT,
    train=True,
    download=True,
    transform=_mnist_preprocess,
)

# Shuffled mini-batches of size opt.batch_size over the training split.
dataloader = torch.utils.data.DataLoader(
    _mnist_train,
    batch_size=opt.batch_size,
    shuffle=True,
)